package com.zhengbo.simplerpc.client;

import com.zhengbo.simplerpc.common.MessageInput;
import com.zhengbo.simplerpc.common.MessageOutput;
import com.zhengbo.simplerpc.common.MessageRegistry;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import lombok.extern.slf4j.Slf4j;

import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.TimeUnit;

/**
 * Created by zhengbo on 2019/8/21.
 */
@Slf4j
@ChannelHandler.Sharable
public class ClientMessageCollector extends ChannelInboundHandlerAdapter {

    private MessageRegistry registry;

    private RpcClient client;

    private ChannelHandlerContext context;

    private ConcurrentMap<String, RpcFuture<?>> pendingTasks = new ConcurrentHashMap<>();

    private Throwable connectionClosed = new Exception("rpc connection not active error");

    public ClientMessageCollector(MessageRegistry registry, RpcClient rpcClient) {
        this.registry = registry;
        this.client = rpcClient;
    }

    @Override
    public void channelActive(ChannelHandlerContext ctx) throws Exception {

        this.context = ctx;
    }

    @Override
    public void channelInactive(ChannelHandlerContext ctx) throws Exception {

        this.context = null;
        pendingTasks.forEach((key, value) -> {

            value.fail(connectionClosed);
        });

        pendingTasks.clear();

        ctx.channel().eventLoop().schedule(() -> {
            this.client.reconnect();
        }, 1, TimeUnit.SECONDS);

    }

    public <T> RpcFuture<T> send(MessageOutput messageOutput) {

        ChannelHandlerContext ctx = context;

        RpcFuture<T> future = new RpcFuture<>();

        if (ctx != null) {
            ctx.channel().eventLoop().execute(() -> {
                pendingTasks.put(messageOutput.getRequestId(), future);
                ctx.writeAndFlush(messageOutput);
            });
        } else {
            future.fail(connectionClosed);
        }

        return future;
    }

    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        if (!(msg instanceof MessageInput)) {
            return;
        }

        MessageInput input = (MessageInput) msg;
        Class<?> clazz = registry.get(input.getType());

        if (clazz == null) {
            log.error("无法识别的消息类型:{}", input.getType());
            return;
        }

        Object payLoad = input.getPayLoad(clazz);

        RpcFuture<Object> rpcFuture = (RpcFuture<Object>) pendingTasks.get(input.getRequestId());
        if (rpcFuture == null) {
            log.error("future not found with type:{}", input.getType());
            return;
        }
        rpcFuture.success(payLoad);
        pendingTasks.remove(input.getRequestId());
    }

    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {

        log.error("client message collector error,ctx:{},cause:{}", ctx, cause);
    }

    public void close() {
        if (context != null) {
            context.close();
        }
    }
}
